"""
Phase 5: Projections — Country Risk Timeline & "Who's Next?"
==============================================================
1. Country risk timeline using estimated coefficients + UN WPP projections
2. "Who's next?" ranking
3. KAOPEN scenario analysis
4. GE overlay (global simultaneous Japanification)

Input:  japanification/data/processed/japan_panel_indexed.csv
        multilateral/followup/data/processed/full_panel.csv (for future Z polynomials)
Output: japanification/output/tables/phase5_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Absolute root of the project checkout; every other path below hangs off it.
# NOTE(review): hard-coded, machine-specific location — adjust per environment.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")

# Sub-project roots.
JAPAN_DIR = PROJECT_DIR.joinpath("japanification")
MULTILATERAL_DIR = PROJECT_DIR.joinpath("multilateral")

# Phase 5 reads its indexed panel from PROCESSED_DIR and writes its
# phase5_*.csv tables to TABLE_DIR.
PROCESSED_DIR = JAPAN_DIR.joinpath("data", "processed")
TABLE_DIR = JAPAN_DIR.joinpath("output", "tables")

# Prepend the multilateral project's src/ directory to the import path so the
# shared PanelGLS estimator can be imported from its `model` module. This must
# run before the import below.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


# Countries tracked in the risk timeline (Section 1) and GE overlay (Section 4).
_FOCUS_COUNTRIES = [
    'JPN', 'DEU', 'ITA', 'GRC', 'KOR', 'CHN', 'THA', 'ESP',
    'USA', 'GBR', 'FRA', 'CAN', 'AUS', 'BRA', 'IND', 'IDN',
    'MEX', 'ZAF', 'NGA', 'SAU', 'TUR', 'POL', 'RUS', 'SGP',
]
# Years at which projections are evaluated. Past decades are included so the
# timeline also shows where each country stood in 2000-2020.
_PROJ_YEARS = [2000, 2010, 2020, 2030, 2040, 2050, 2060]
# Countries compared in the KAOPEN scenario analysis (Section 3).
_SCENARIO_COUNTRIES = ['JPN', 'DEU', 'KOR', 'CHN', 'USA', 'IND', 'BRA', 'GBR']
# KAOPEN values used for the "Fully Open" and "Closed" scenarios.
# NOTE(review): presumably chosen near the extremes of the Chinn-Ito index;
# confirm against the KAOPEN vintage in the panel.
_KAOPEN_FULLY_OPEN = 2.4
_KAOPEN_CLOSED = -1.9
# Column layout of the projection frame. It is kept explicit so that an empty
# frame still has the columns downstream code selects on.
_PROJ_COLUMNS = ['iso3', 'year', 'demo_contribution',
                 'japan_index_projected', 'exceeds_threshold']


def _print_section(title):
    """Print a section banner in the script's standard format."""
    print(f"\n{'=' * 70}")
    print(f"  {title}")
    print("=" * 70)


def _fit_models(df, dep_var, demo_vars, controls):
    """Fit the baseline and KAOPEN-interaction PanelGLS models.

    Returns:
        (model, base_coefs, model_int, int_coefs): the fitted models and, for
        each, a dict mapping regressor name to its estimated coefficient.
    """
    all_vars = demo_vars + controls
    est = df.dropna(subset=[dep_var] + all_vars)
    model = PanelGLS()
    model.fit(est[dep_var].values, est[all_vars].values,
              est['iso3'].values, est['year'].values)
    model.summary(feature_names=all_vars)
    # Assumes model.beta is ordered like the regressor columns passed to fit().
    base_coefs = dict(zip(all_vars, model.beta))

    # KAOPEN interaction model. Interaction columns absent from the panel are
    # left out of the fit and contribute zero in the scenario analysis.
    int_vars = [f'{zv}_x_kaopen' for zv in demo_vars]
    avail_int = [v for v in int_vars if v in df.columns]
    missing_int = [v for v in int_vars if v not in df.columns]
    if missing_int:
        print(f"  Warning: interaction columns not in panel, treated as zero: {missing_int}")
    all_vars_int = all_vars + avail_int
    est_int = df.dropna(subset=[dep_var] + all_vars_int)
    model_int = PanelGLS()
    model_int.fit(est_int[dep_var].values, est_int[all_vars_int].values,
                  est_int['iso3'].values, est_int['year'].values)
    int_coefs = dict(zip(all_vars_int, model_int.beta))
    return model, base_coefs, model_int, int_coefs


def _japan_threshold(df, dep_var, year=2000):
    """Return Japan's index level in `year`, used as the onset threshold.

    Falls back to Japan's mean index when `year` has no non-missing value.

    Raises:
        ValueError: if JPN has no non-missing `dep_var` value at all. A NaN
            threshold would silently make every `> threshold` test False.
    """
    jpn = df.loc[df['iso3'] == 'JPN', ['year', dep_var]].dropna(subset=[dep_var])
    at_year = jpn.loc[jpn['year'] == year, dep_var]
    threshold = at_year.iloc[0] if len(at_year) > 0 else jpn[dep_var].mean()
    if pd.isna(threshold):
        raise ValueError(
            f"No non-missing '{dep_var}' values for JPN; cannot set the onset threshold")
    return threshold


def _project_countries(fp, last_controls, coefs, constant, demo_vars,
                       controls, threshold):
    """Project the index for every focus country at each of _PROJ_YEARS.

    Demographic terms are taken from `fp` for the target year. Controls are
    held at each country's last observed values. A country is skipped when it
    has no control values or no rows in `fp`, and a year is skipped when `fp`
    has no row for it.
    """
    rows = []
    for iso3 in _FOCUS_COUNTRIES:
        cdata = fp[fp['iso3'] == iso3]
        if iso3 not in last_controls.index or cdata.empty:
            continue
        ctrl = last_controls.loc[iso3]
        # Controls are fixed across years, so their contribution is computed once.
        control_effect = sum(coefs[cv] * ctrl[cv] for cv in controls)

        for year in _PROJ_YEARS:
            yr = cdata[cdata['year'] == year]
            if yr.empty:
                continue
            # Demographic contribution only (Z terms).
            demo_effect = sum(coefs[zv] * yr[zv].values[0] for zv in demo_vars)
            total = demo_effect + control_effect + constant
            rows.append({
                'iso3': iso3,
                'year': year,
                'demo_contribution': demo_effect,
                'japan_index_projected': total,
                'exceeds_threshold': total > threshold,
            })
    return pd.DataFrame(rows, columns=_PROJ_COLUMNS)


def _projected_value(cproj, year):
    """Return the projected index for `year` in `cproj`, or NaN if that year is absent."""
    vals = cproj.loc[cproj['year'] == year, 'japan_index_projected']
    return vals.values[0] if len(vals) > 0 else np.nan


def _crossing_table(proj_df, threshold):
    """Summarise each focus country's projected path.

    Each row holds the 2020 level, the first projected year above `threshold`
    (NaN if never), and the 2030/2040/2050 projections. Rows are sorted by
    crossing year, with NaN last.
    """
    columns = ['iso3', 'current_index', 'crossing_year',
               'proj_2030', 'proj_2040', 'proj_2050']
    rows = []
    for iso3 in _FOCUS_COUNTRIES:
        cproj = proj_df[proj_df['iso3'] == iso3].sort_values('year')
        if cproj.empty:
            continue
        above = cproj.loc[cproj['japan_index_projected'] > threshold, 'year']
        row = {
            'iso3': iso3,
            'current_index': _projected_value(cproj, 2020),
            'crossing_year': above.min() if len(above) > 0 else np.nan,
        }
        for yr in (2030, 2040, 2050):
            row[f'proj_{yr}'] = _projected_value(cproj, yr)
        rows.append(row)
    return pd.DataFrame(rows, columns=columns).sort_values('crossing_year')


def _whos_next(df, dep_var):
    """Rank countries by current level and by recent change of the index.

    Prints both rankings and writes them to TABLE_DIR.
    """
    # Current (2020-2024 average) index for all countries.
    recent = df[(df['year'] >= 2020) & (df['year'] <= 2024)]
    country_current = (recent.groupby('iso3')[dep_var]
                       .agg(['mean', 'std', 'count'])
                       .rename(columns={'mean': 'current_mean', 'std': 'current_std',
                                        'count': 'n_years'})
                       .sort_values('current_mean', ascending=False))
    # OADR ('old_dep') averaged over the same window, for context.
    country_current['oadr_mean'] = recent.groupby('iso3')['old_dep'].mean()

    print("\n  Top 20 Countries by Current Japanification Index:")
    print(country_current.head(20).to_string(float_format='%.3f'))
    country_current.to_csv(TABLE_DIR / "phase5_whos_next_ranking.csv")

    # Emerging risks: change in mean index from 2010-2014 to 2020-2024.
    early_avg = (df[(df['year'] >= 2010) & (df['year'] <= 2014)]
                 .groupby('iso3')[dep_var].mean())
    late_avg = recent.groupby('iso3')[dep_var].mean()
    delta = (late_avg - early_avg).dropna().sort_values(ascending=False)
    delta_df = delta.reset_index()
    delta_df.columns = ['iso3', 'japan_index_change_10yr']

    print("\n  Top 15 Fastest Japanification (2010s→2020s change):")
    print(delta_df.head(15).to_string(index=False, float_format='%.3f'))
    delta_df.to_csv(TABLE_DIR / "phase5_fastest_japanification.csv", index=False)


def _kaopen_scenarios(fp, last_controls, coefs, constant, demo_vars, controls):
    """Compare 2040 projections under three KAOPEN scenarios.

    Uses the interaction-model coefficients for current, fully open and closed
    KAOPEN. The pivoted table is printed and written when any country can be
    projected.
    """
    rows = []
    for iso3 in _SCENARIO_COUNTRIES:
        cdata = fp[(fp['iso3'] == iso3) & (fp['year'] == 2040)]
        if iso3 not in last_controls.index or cdata.empty:
            continue
        ctrl = last_controls.loc[iso3]
        z_vals = {zv: cdata[zv].values[0] for zv in demo_vars}
        # The demographic main effect does not depend on the scenario.
        demo_eff = sum(coefs[zv] * z_vals[zv] for zv in demo_vars)

        for scenario, kaopen_val in [('Current', ctrl['kaopen']),
                                     ('Fully Open', _KAOPEN_FULLY_OPEN),
                                     ('Closed', _KAOPEN_CLOSED)]:
            ctrl_eff = sum(coefs[cv] * (kaopen_val if cv == 'kaopen' else ctrl[cv])
                           for cv in controls)
            # Z x KAOPEN interaction terms. Coefficients absent from the model count as 0.
            int_eff = sum(coefs.get(f'{zv}_x_kaopen', 0) * z_vals[zv] * kaopen_val
                          for zv in demo_vars)
            total = demo_eff + ctrl_eff + int_eff + constant
            rows.append({
                'iso3': iso3,
                'scenario': scenario,
                'kaopen': kaopen_val,
                'japan_index_2040': total,
            })

    if not rows:
        return
    pivot = (pd.DataFrame(rows)
             .pivot(index='iso3', columns='scenario', values='japan_index_2040')
             [['Current', 'Fully Open', 'Closed']]
             .sort_values('Current', ascending=False))
    print("\n  Japanification Index 2040 Under KAOPEN Scenarios:")
    print(pivot.to_string(float_format='%.3f'))
    pivot.to_csv(TABLE_DIR / "phase5_kaopen_scenarios.csv")


def _ge_overlay(proj_df):
    """Count focus countries above the onset threshold per decade.

    The trajectory table is printed and written to TABLE_DIR.
    """
    ge_rows = []
    for year in (2000, 2010, 2020, 2030, 2040, 2050):
        yr_proj = proj_df[proj_df['year'] == year]
        n_above = yr_proj['exceeds_threshold'].sum()
        n_total = len(yr_proj)
        ge_rows.append({
            'year': year,
            'n_above_threshold': n_above,
            'n_countries': n_total,
            'share_above': n_above / n_total if n_total > 0 else 0,
            'mean_japan_index': yr_proj['japan_index_projected'].mean(),
        })

    ge_df = pd.DataFrame(ge_rows)
    print("\n  Global Japanification Trajectory (focus countries):")
    print(ge_df.to_string(index=False, float_format='%.3f'))
    ge_df.to_csv(TABLE_DIR / "phase5_ge_trajectory.csv", index=False)

    # Interpretive note only: no GDP weighting or GE clearing is computed here.
    print("\n  Note: If many large economies Japanify simultaneously,")
    print("  global rate compression accelerates, reinforcing the cycle.")
    print("  Full GE clearing analysis available in multilateral/clearing_channels/.")


def main():
    """Run Phase 5 end to end.

    Loads the indexed Japan panel (the estimation sample) and the multilateral
    full panel, which supplies Z_1..Z_3 for years beyond the sample. It then
    fits the baseline and KAOPEN-interaction PanelGLS models and produces:

    1. a country risk timeline (phase5_country_timeline.csv),
    2. "who's next" rankings (phase5_whos_next_ranking.csv,
       phase5_fastest_japanification.csv),
    3. KAOPEN scenarios for 2040 (phase5_kaopen_scenarios.csv),
    4. counts of focus countries above Japan's onset threshold per decade
       (phase5_ge_trajectory.csv).

    All tables are written to TABLE_DIR, which is created if missing.

    Raises:
        ValueError: if no usable Japan onset threshold can be derived.
    """
    print("=" * 70)
    print("PHASE 5: Projections — Country Risk Timeline")
    print("=" * 70)

    # to_csv does not create parent directories.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded panel: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Full panel supplies future Z polynomials (extends beyond 2024).
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    print(f"Loaded full panel for projections: {fp.shape[0]:,} rows, "
          f"{fp['year'].min()}-{fp['year'].max()}")

    dep_var = 'japan_index_2c'
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']

    # 0. Estimate baseline and interaction models for projection coefficients.
    model, base_coefs, model_int, int_coefs = _fit_models(df, dep_var, demo_vars, controls)

    # 1. Country risk timeline.
    _print_section("1. Country Risk Timeline")
    japan_threshold = _japan_threshold(df, dep_var)
    print(f"  Japan onset threshold (Japan ~2000): {japan_threshold:.3f}")

    # Last observed control values per country, held constant in all projections.
    last_controls = (df.dropna(subset=controls)
                     .sort_values('year')
                     .groupby('iso3')[controls]
                     .last())

    proj_df = _project_countries(fp, last_controls, base_coefs, model.constant,
                                 demo_vars, controls, japan_threshold)
    crossing_df = _crossing_table(proj_df, japan_threshold)

    print("\n  Country Japanification Timeline:")
    print(crossing_df.to_string(index=False, float_format='%.3f'))
    crossing_df.to_csv(TABLE_DIR / "phase5_country_timeline.csv", index=False)

    # 2. "Who's next?" ranking.
    _print_section("2. Who's Next? — Current Japanification Risk Ranking")
    _whos_next(df, dep_var)

    # 3. KAOPEN scenario analysis.
    _print_section("3. KAOPEN Scenario Analysis")
    _kaopen_scenarios(fp, last_controls, int_coefs, model_int.constant,
                      demo_vars, controls)

    # 4. GE overlay: global simultaneous Japanification.
    _print_section("4. GE Overlay — Simultaneous Global Japanification")
    _ge_overlay(proj_df)

    print(f"\n{'=' * 70}")
    print(f"Phase 5 complete. Tables saved to {TABLE_DIR}")
    print("=" * 70)


# Run the Phase 5 pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
